{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automated Segmentation of Remote Sensing Imagery with SAM 3\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/segment-geospatial/blob/main/docs/examples/sam3_automated_segmentation.ipynb)\n",
    "\n",
    "This notebook demonstrates automated segmentation of remote sensing imagery with SAM 3. The workflow starts with image captioning, which identifies key features in the image. These features are then used as text prompts for SAM 3 to segment the corresponding objects.\n",
    "\n",
    "## Installation\n",
    "\n",
    "Uncomment and run the following cell if the required dependencies are not installed yet:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install \"segment-geospatial[samgeo3]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import leafmap\n",
    "from samgeo import SamGeo3, download_file\n",
    "from samgeo.caption import ImageCaptioner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Sample Data\n",
    "\n",
    "Let's download a sample satellite image covering the University of California, Berkeley, for testing:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/uc_berkeley.tif\"\n",
    "image_path = download_file(url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = leafmap.Map()\n",
    "m.add_raster(image_path, layer_name=\"Satellite image\")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
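    "## Image captioning\n",
    "\n",
    "The captioner is configured with a BLIP captioning model and a spaCy language model. Its `analyze` method returns a caption for the image together with a list of features, which serve as candidate text prompts."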
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "captioner = ImageCaptioner(\n",
    "    blip_model_name=\"Salesforce/blip-image-captioning-base\",\n",
    "    spacy_model_name=\"en_core_web_sm\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "caption, features = captioner.analyze(image_path)\n",
    "print(f\"Caption: {caption}\")\n",
    "print(f\"Features: {features}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Request access to SAM 3\n",
    "\n",
    "To use SAM 3, you need to request access by filling out the form on its Hugging Face model page: https://huggingface.co/facebook/sam3\n",
    "\n",
    "Once access has been granted, uncomment the following code block and run it to log in to Hugging Face."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from huggingface_hub import login\n",
    "# login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize SAM 3\n",
    "\n",
    "When initializing SAM 3, you can choose either the \"meta\" or the \"transformers\" backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam3 = SamGeo3(backend=\"meta\", device=None, checkpoint_path=None, load_from_HF=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the image\n",
    "\n",
    "You can set the image by passing either a local file path or an image URL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam3.set_image(image_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
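    "## Generate masks with a text prompt\n",
    "\n",
    "The first feature returned by the captioner is used as the text prompt. Any other entry in `features` can be used in the same way."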
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first feature found by the captioner as the text prompt\n",
    "feature = features[0]\n",
    "sam3.generate_masks(prompt=feature)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam3.show_anns()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![annotation](https://github.com/user-attachments/assets/64323223-35a2-4e03-9cee-1b60fa0c12af)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam3.show_masks()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Masks\n",
    "\n",
    "Save the generated masks to a file. If the input is a GeoTIFF, the output will be a GeoTIFF with the same georeferencing. Otherwise, the output will be saved as a PNG."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save masks with unique values for each object\n",
    "# Since uc_berkeley.tif is a GeoTIFF, the output will also be a GeoTIFF\n",
    "sam3.save_masks(output=f\"{feature}_masks.tif\", unique=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save as binary mask (all foreground pixels are 255)\n",
    "sam3.save_masks(output=f\"{feature}_masks_binary.tif\", unique=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Masks with Confidence Scores\n",
    "\n",
    "You can also save a confidence score for each mask, which indicates how confident the model is in that prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save masks and confidence scores\n",
    "# Each pixel in the scores image will have the confidence value of its mask\n",
    "sam3.save_masks(\n",
    "    output=f\"{feature}_masks_with_scores.tif\",\n",
    "    save_scores=f\"{feature}_scores.tif\",\n",
    "    unique=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sam3.show_masks(cmap=\"coolwarm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![scores](https://github.com/user-attachments/assets/23ec9b07-0de9-4f72-81b2-83a3c499e94e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Confidence Scores\n",
    "\n",
    "Let's visualize the confidence scores to see which masks have higher confidence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.add_raster(f\"{feature}_masks.tif\", layer_name=f\"{feature} masks\", visible=False)\n",
    "m.add_raster(\n",
    "    f\"{feature}_scores.tif\",\n",
    "    layer_name=f\"{feature} scores\",\n",
    "    cmap=\"coolwarm\",\n",
    "    opacity=0.8,\n",
    "    nodata=0,\n",
    ")\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![map](https://github.com/user-attachments/assets/fa21320d-b4f3-48f9-a3f2-828f4ed1c567)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
